{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0da2800a-4e53-485a-bc78-f1d178c31207",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eb15eb73-9b5d-4279-b3f4-63a54393acfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['font.sans-serif'] = 'SimHei'\n",
    "plt.rcParams['axes.unicode_minus'] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "712fe544-97fd-46c7-8f34-05ac3263ac20",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fde8660a-3774-474a-a14e-7b949a89e7c9",
   "metadata": {},
   "source": [
    "## 年收入大于50K二分类"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b4bdd24-a84a-429f-ac06-6abc061d270f",
   "metadata": {},
   "source": [
    "### 1、数据获取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b87bfbb7-e828-47ba-848d-cad1b38f4c7c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>final_weight</th>\n",
       "      <th>education</th>\n",
       "      <th>education_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "      <th>salary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>Cuba</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  final_weight  education  education_num  \\\n",
       "0   39         State-gov         77516  Bachelors             13   \n",
       "1   50  Self-emp-not-inc         83311  Bachelors             13   \n",
       "2   38           Private        215646    HS-grad              9   \n",
       "3   53           Private        234721       11th              7   \n",
       "4   28           Private        338409  Bachelors             13   \n",
       "\n",
       "       marital_status         occupation   relationship   race     sex  \\\n",
       "0       Never-married       Adm-clerical  Not-in-family  White    Male   \n",
       "1  Married-civ-spouse    Exec-managerial        Husband  White    Male   \n",
       "2            Divorced  Handlers-cleaners  Not-in-family  White    Male   \n",
       "3  Married-civ-spouse  Handlers-cleaners        Husband  Black    Male   \n",
       "4  Married-civ-spouse     Prof-specialty           Wife  Black  Female   \n",
       "\n",
       "   capital_gain  capital_loss  hours_per_week native_country salary  \n",
       "0          2174             0              40  United-States  <=50K  \n",
       "1             0             0              13  United-States  <=50K  \n",
       "2             0             0              40  United-States  <=50K  \n",
       "3             0             0              40  United-States  <=50K  \n",
       "4             0             0              40           Cuba  <=50K  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 数据事先采集，直接从dataset文件夹中导入\n",
    "df = pd.read_csv('../dataset/adults.txt', sep=',')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a62e3612-4d4a-4204-9963-60d40f9ae820",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0        <=50K\n",
       "1        <=50K\n",
       "2        <=50K\n",
       "3        <=50K\n",
       "4        <=50K\n",
       "         ...  \n",
       "32556    <=50K\n",
       "32557     >50K\n",
       "32558    <=50K\n",
       "32559    <=50K\n",
       "32560     >50K\n",
       "Name: salary, Length: 32561, dtype: object"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# salary这一列是标签\n",
    "df['salary']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57df44f8-49c9-43a7-9cd1-4870b360edeb",
   "metadata": {},
   "source": [
    "### 2、数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "71733823-a8c1-4adc-9d56-a6a569fcf859",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature table: every column of df except the last one (the 'salary' label).\n",
    "# .copy() gives data its own storage, so assigning encoded columns to it later\n",
    "# cannot alias df and does not trigger pandas' SettingWithCopyWarning.\n",
    "data = df.iloc[:, :-1].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "95e4fb7a-2d67-4efd-b49a-6e0b8a0aaf2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 32561 entries, 0 to 32560\n",
      "Data columns (total 14 columns):\n",
      " #   Column          Non-Null Count  Dtype \n",
      "---  ------          --------------  ----- \n",
      " 0   age             32561 non-null  int64 \n",
      " 1   workclass       32561 non-null  object\n",
      " 2   final_weight    32561 non-null  int64 \n",
      " 3   education       32561 non-null  object\n",
      " 4   education_num   32561 non-null  int64 \n",
      " 5   marital_status  32561 non-null  object\n",
      " 6   occupation      32561 non-null  object\n",
      " 7   relationship    32561 non-null  object\n",
      " 8   race            32561 non-null  object\n",
      " 9   sex             32561 non-null  object\n",
      " 10  capital_gain    32561 non-null  int64 \n",
      " 11  capital_loss    32561 non-null  int64 \n",
      " 12  hours_per_week  32561 non-null  int64 \n",
      " 13  native_country  32561 non-null  object\n",
      "dtypes: int64(6), object(8)\n",
      "memory usage: 3.5+ MB\n"
     ]
    }
   ],
   "source": [
    "data.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "41c5c3d1-8b17-431a-882a-00cdf887c04a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['State-gov', 'Self-emp-not-inc', 'Private', 'Federal-gov',\n",
       "       'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 获取'workclass'列中的唯一值\n",
    "unique_values = data['workclass'].unique()\n",
    "unique_values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08faad8f-ea46-47f6-ac14-a1fbab36e87e",
   "metadata": {},
   "source": [
    "#### ① 对字符串类型的列进行编码：方法一"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fac0ea00-20af-49ed-a4a0-af7b1a37f539",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>final_weight</th>\n",
       "      <th>education</th>\n",
       "      <th>education_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>Cuba</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  final_weight  education  education_num  \\\n",
       "0   39         State-gov         77516  Bachelors             13   \n",
       "1   50  Self-emp-not-inc         83311  Bachelors             13   \n",
       "2   38           Private        215646    HS-grad              9   \n",
       "3   53           Private        234721       11th              7   \n",
       "4   28           Private        338409  Bachelors             13   \n",
       "\n",
       "       marital_status         occupation   relationship   race     sex  \\\n",
       "0       Never-married       Adm-clerical  Not-in-family  White    Male   \n",
       "1  Married-civ-spouse    Exec-managerial        Husband  White    Male   \n",
       "2            Divorced  Handlers-cleaners  Not-in-family  White    Male   \n",
       "3  Married-civ-spouse  Handlers-cleaners        Husband  Black    Male   \n",
       "4  Married-civ-spouse     Prof-specialty           Wife  Black  Female   \n",
       "\n",
       "   capital_gain  capital_loss  hours_per_week native_country  \n",
       "0          2174             0              40  United-States  \n",
       "1             0             0              13  United-States  \n",
       "2             0             0              40  United-States  \n",
       "3             0             0              40  United-States  \n",
       "4             0             0              40           Cuba  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 先查看一下原始的data\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "559a502f-737d-480f-953f-51f663b09b0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cols = ['workclass', 'education', 'marital_status', 'occupation', 'relationship', 'race', 'sex', 'native_country']\n",
    "# 从数据集中筛选出所有数据类型为'object'的列，并将这些列的名称存储在列表cols中\n",
    "cols = [x for x in data.columns if data[x].dtype == 'object']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1c6a4081-e0df-41b1-ae47-74774d91a1db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['age', 'workclass', 'final_weight', 'education', 'education_num',\n",
       "       'marital_status', 'occupation', 'relationship', 'race', 'sex',\n",
       "       'capital_gain', 'capital_loss', 'hours_per_week', 'native_country'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8880f2c9-b5d0-412d-b66f-144280ec67dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['State-gov', 'Self-emp-not-inc', 'Private', 'Federal-gov',\n",
       "       'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data['workclass'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "105917d2-72ee-45c8-815d-9608bdf1a238",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2]], dtype=int64)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.argwhere(data['workclass'].unique()=='Private')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "8ac43f3a-33a9-43cf-bc4e-d22fe23c835d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert(x, categories=None):\n",
    "    \"\"\"Return the integer code of x: its position inside categories.\n",
    "\n",
    "    categories holds the distinct values of one column. When it is omitted,\n",
    "    the module-level name `unique` is used, which keeps the original\n",
    "    single-argument call convert(x) working.\n",
    "    Raises IndexError when x does not occur in categories.\n",
    "    \"\"\"\n",
    "    lookup = unique if categories is None else categories\n",
    "    return np.argwhere(lookup == x)[0][0]\n",
    "\n",
    "# Encode every string column as integer codes numbered by first appearance.\n",
    "for col in cols:\n",
    "    # pd.factorize produces all codes in one vectorised pass; it replaces the\n",
    "    # previous per-row np.argwhere scan and yields the same numbering, because\n",
    "    # its uniques are ordered by first appearance just like Series.unique().\n",
    "    # `unique` is still bound here so convert(x) keeps its fallback lookup.\n",
    "    codes, unique = pd.factorize(data[col])\n",
    "    data[col] = codes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7ec4acd9-4733-430f-81bc-f441294f4820",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>final_weight</th>\n",
       "      <th>education</th>\n",
       "      <th>education_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>0</td>\n",
       "      <td>77516</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>1</td>\n",
       "      <td>83311</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>2</td>\n",
       "      <td>215646</td>\n",
       "      <td>1</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>2</td>\n",
       "      <td>234721</td>\n",
       "      <td>2</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>2</td>\n",
       "      <td>338409</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  final_weight  education  education_num  marital_status  \\\n",
       "0   39          0         77516          0             13               0   \n",
       "1   50          1         83311          0             13               1   \n",
       "2   38          2        215646          1              9               2   \n",
       "3   53          2        234721          2              7               1   \n",
       "4   28          2        338409          0             13               1   \n",
       "\n",
       "   occupation  relationship  race  sex  capital_gain  capital_loss  \\\n",
       "0           0             0     0    0          2174             0   \n",
       "1           1             1     0    0             0             0   \n",
       "2           2             0     0    0             0             0   \n",
       "3           2             1     1    0             0             0   \n",
       "4           3             2     1    1             0             0   \n",
       "\n",
       "   hours_per_week  native_country  \n",
       "0              40               0  \n",
       "1              13               0  \n",
       "2              40               0  \n",
       "3              40               0  \n",
       "4              40               1  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "938eff0a-9280-4cd5-b85e-88f2a08d6da4",
   "metadata": {},
   "source": [
    "#### ② 对字符串类型的列进行编码：方法二"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ce37055-b89b-48c6-b469-7f722e9295ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 对无法转换成float的列进行编码\n",
    "from sklearn.preprocessing import LabelEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0afe59f3-d77e-4773-be0f-780ce70da953",
   "metadata": {},
   "outputs": [],
   "source": [
    "for col in cols:\n",
    "    # Create a fresh LabelEncoder for this column\n",
    "    le = LabelEncoder()\n",
    "    # Fit the encoder on data[col] and write the encoded values back into the same column\n",
    "    data[col] = le.fit_transform(data[col])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "937d4b85-3c47-4f28-8fd2-9af52b252775",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fb164d4-81cd-4a4c-9b64-ad80a29b577f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def salary_func(x):\n",
    "    \"\"\"Encode one salary label: 1 for '>50K', 0 for every other label.\n",
    "\n",
    "    An already-encoded 1 stays 1, so running this cell a second time does\n",
    "    not silently turn every label into 0.\n",
    "    \"\"\"\n",
    "    if not isinstance(x, str):\n",
    "        # Non-text input: keep an existing 1, map anything else to 0 as before.\n",
    "        return 1 if x == 1 else 0\n",
    "    # strip() tolerates whitespace padding around the label text.\n",
    "    return 1 if x.strip() == '>50K' else 0\n",
    "\n",
    "df['salary'] = df['salary'].apply(salary_func)\n",
    "# Number of rows labelled '>50K'\n",
    "df['salary'].sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ff556f4-8052-498f-948b-f00a0ba69538",
   "metadata": {},
   "source": [
    "### 3、划分数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "57c42af1-e4ed-4eea-90a0-85e29cce4c9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入sklearn库中的train_test_split函数，用于将数据集划分为训练集和测试集\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6e600b30-ec05-4818-81bd-e9a26307a1a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 25% of the rows as the test set.\n",
    "# random_state pins the shuffle so a fresh Restart & Run All reproduces the same split and score;\n",
    "# stratify keeps the share of each salary class equal in the train and test parts.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    data.values, df['salary'].values, test_size=0.25, random_state=42, stratify=df['salary'].values\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8d0d1b2f-edae-45d8-9f92-88d8ef3c3088",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[    20,      2, 224238, ...,      0,     40,      0],\n",
       "       [    38,      2, 607848, ...,      0,     40,      0],\n",
       "       [    46,      2,  28419, ...,      0,     50,      0],\n",
       "       ...,\n",
       "       [    36,      0, 135874, ...,      0,     40,      0],\n",
       "       [    41,      2, 171234, ...,      0,     55,      0],\n",
       "       [    43,      2, 184378, ...,      0,     35,      0]], dtype=int64)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4092cae8-7fd0-456f-9730-68edc9d04d11",
   "metadata": {},
   "source": [
    "### 4、训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "e23c6de7-9082-49bc-8f8d-1e0010250a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "knn = KNeighborsClassifier()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5286f338-2fde-40ca-b3ad-4ad8721662a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>KNeighborsClassifier()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">KNeighborsClassifier</label><div class=\"sk-toggleable__content\"><pre>KNeighborsClassifier()</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "KNeighborsClassifier()"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a10830a-be8b-469c-b9ce-350347d65233",
   "metadata": {},
   "source": [
    "### 5、模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1c830990-7f4a-43e4-9387-b49dae2db72e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label every row of X_test with the fitted KNN classifier; the result is kept in y_pred\n",
    "y_pred = knn.predict(X=X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6b0ebf2f-e969-4b12-b34f-0a22f96f644b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['>50K', '>50K', '>50K', ..., '<=50K', '<=50K', '>50K'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick look at the predicted labels (the array repr is abbreviated)\n",
    "y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2d6938bc-419e-4960-8bc8-5e265c6f19f5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['>50K', '<=50K', '<=50K', ..., '<=50K', '<=50K', '>50K'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ground-truth test labels, shown for an eyeball comparison with y_pred\n",
    "y_test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc7003bc-4230-4676-8593-e52ea8259734",
   "metadata": {},
   "source": [
    "### 6、模型评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "21220950-bb6e-452a-8f9f-f3f01940fa90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7812308070261639"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test-set prediction accuracy, computed with the KNN estimator's own score() method\n",
    "accuracy = knn.score(X=X_test, y=y_test)\n",
    "accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "e63aa79e-f7e5-4fbb-bc21-714daba9b2cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 分类评估报告\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "595b4d0f-cc6a-46fb-a106-977f13610f01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       <=50K       0.81      0.93      0.87      6173\n",
      "        >50K       0.58      0.33      0.42      1968\n",
      "\n",
      "    accuracy                           0.78      8141\n",
      "   macro avg       0.70      0.63      0.64      8141\n",
      "weighted avg       0.76      0.78      0.76      8141\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-class precision / recall / F1 table for the test-set predictions\n",
    "print(classification_report(y_test, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bd44669-6906-401f-be3f-31d9cce262f7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
